{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import sklearn.datasets\n",
    "import re\n",
    "import time\n",
    "import pickle\n",
    "from sklearn.cross_validation import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"Stacked-LSTM sequence classifier built as a TensorFlow 1.x graph.\n",
    "\n",
    "    The input is run through `num_layers` LSTM cells with output dropout,\n",
    "    the output of the last time step is projected to `dimension_output`\n",
    "    logits, and the loss is softmax cross-entropy plus an L2 penalty.\n",
    "\n",
    "    Attributes:\n",
    "        X: float32 placeholder, shape [None, None, dimension_input].\n",
    "        Y: float32 placeholder, shape [None, dimension_output]; used as the\n",
    "            cross-entropy labels and arg-maxed to score accuracy.\n",
    "        keep_prob: scalar dropout keep probability. Defaults to 0.5 when it\n",
    "            is not fed; feed 1.0 to disable dropout during evaluation.\n",
    "        logits: un-normalized class scores from the last time step.\n",
    "        cost: mean cross-entropy plus the L2 penalty.\n",
    "        optimizer: Adam training op minimizing `cost`.\n",
    "        correct_pred: per-sample boolean, argmax(logits) == argmax(Y).\n",
    "        accuracy: mean of `correct_pred` cast to float32.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_layers, size_layer, dimension_input, dimension_output, learning_rate):\n",
    "        \"\"\"Create all placeholders, variables and ops in the default graph.\n",
    "\n",
    "        Args:\n",
    "            num_layers: number of stacked LSTM cells.\n",
    "            size_layer: number of units in every LSTM cell.\n",
    "            dimension_input: size of the last axis of `X`.\n",
    "            dimension_output: number of classes (width of `Y` and `logits`).\n",
    "            learning_rate: learning rate handed to the Adam optimizer.\n",
    "        \"\"\"\n",
    "\n",
    "        def lstm_cell():\n",
    "            return tf.nn.rnn_cell.LSTMCell(size_layer)\n",
    "\n",
    "        self.rnn_cells = tf.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(num_layers)])\n",
    "\n",
    "        self.X = tf.placeholder(tf.float32, [None, None, dimension_input])\n",
    "        self.Y = tf.placeholder(tf.float32, [None, dimension_output])\n",
    "\n",
    "        # Dropout keep probability. It used to be a hard-coded 0.5, which meant\n",
    "        # units were also dropped whenever accuracy was evaluated. The default\n",
    "        # keeps the old behaviour for callers that do not feed it.\n",
    "        self.keep_prob = tf.placeholder_with_default(0.5, shape = [])\n",
    "        drop = tf.contrib.rnn.DropoutWrapper(self.rnn_cells, output_keep_prob = self.keep_prob)\n",
    "\n",
    "        self.outputs, self.last_state = tf.nn.dynamic_rnn(drop, self.X, dtype = tf.float32)\n",
    "\n",
    "        # Dense projection of the last time step onto the classes.\n",
    "        self.rnn_W = tf.Variable(tf.random_normal((size_layer, dimension_output)))\n",
    "        self.rnn_B = tf.Variable(tf.random_normal([dimension_output]))\n",
    "        self.logits = tf.matmul(self.outputs[:, -1], self.rnn_W) + self.rnn_B\n",
    "\n",
    "        self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = self.logits, labels = self.Y))\n",
    "\n",
    "        # L2 regularization over every trainable variable currently registered\n",
    "        # in the default graph (weights and biases alike).\n",
    "        l2 = sum(0.0005 * tf.nn.l2_loss(tf_var) for tf_var in tf.trainable_variables())\n",
    "\n",
    "        self.cost += l2\n",
    "\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "\n",
    "        self.correct_pred = tf.equal(tf.argmax(self.logits, 1), tf.argmax(self.Y, 1))\n",
    "\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def clearstring(string):\n",
    "    \"\"\"Return `string` reduced to letters, digits and quotes, single-spaced.\n",
    "\n",
    "    Every character other than A-Z, a-z, 0-9, space, single quote and double\n",
    "    quote is deleted; the remaining words are joined with one space each,\n",
    "    which also removes leading and trailing spaces.\n",
    "    \"\"\"\n",
    "    kept = re.sub('[^\\\"\\'A-Za-z0-9 ]+', '', string)\n",
    "    # Splitting on a single space yields '' for runs of spaces; skip those.\n",
    "    words = [word.strip() for word in kept.split(' ') if word]\n",
    "    return ' '.join(words)\n",
    "\n",
    "# sklearn.datasets hands back every document as one single string,\n",
    "# so each document is split into one sample per non-empty line.\n",
    "def separate_dataset(trainset):\n",
    "    \"\"\"Explode the documents of `trainset` into cleaned lines with labels.\n",
    "\n",
    "    Args:\n",
    "        trainset: object with parallel `data` (strings) and `target` sequences.\n",
    "\n",
    "    Returns:\n",
    "        (sentences, labels): two lists of equal length. `sentences` holds every\n",
    "        non-empty line of every document after `clearstring`; `labels` repeats\n",
    "        the target of the document each line came from.\n",
    "    \"\"\"\n",
    "    sentences, labels = [], []\n",
    "    for index, document in enumerate(trainset.data):\n",
    "        # Empty lines are dropped before cleaning, exactly one label per kept line.\n",
    "        lines = [clearstring(line) for line in document.split('\\n') if line]\n",
    "        sentences.extend(lines)\n",
    "        labels.extend(trainset.target[index] for _ in lines)\n",
    "    return sentences, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Read the corpus stored under 'data' as UTF-8 text, then overwrite the\n",
    "# per-document data/target with one entry per cleaned line.\n",
    "trainset_data = sklearn.datasets.load_files(encoding = 'UTF-8', container_path = 'data')\n",
    "(trainset_data.data, trainset_data.target) = separate_dataset(trainset_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): pickle.load can execute arbitrary code while unpickling;\n",
    "# only use a dictionary_emotion.p that comes from a trusted source.\n",
    "# presumably maps token -> numeric id (it is indexed by token later) - TODO confirm\n",
    "with open('dictionary_emotion.p', mode = 'rb') as pickled_dict:\n",
    "    dict_emotion = pickle.load(pickled_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Number of whitespace-separated tokens in every sample.\n",
    "len_sentences = np.array([len(sentence.split()) for sentence in trainset_data.data])\n",
    "# The sequence length is the mean token count rounded up, not the maximum.\n",
    "maxlen = np.ceil(np.mean(len_sentences)).astype('int')\n",
    "# One row per sample; cells that are never written stay 0.\n",
    "data_X = np.zeros(shape = (len(trainset_data.data), maxlen))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Fill data_X right-aligned: the first `maxlen` tokens of a sample are\n",
    "# written from the last column backwards, so shorter samples keep zeros\n",
    "# on the left.\n",
    "for i in range(data_X.shape[0]):\n",
    "    tokens = trainset_data.data[i].split()[:maxlen]\n",
    "    for no, text in enumerate(tokens[::-1]):\n",
    "        try:\n",
    "            data_X[i, -1 - no] = dict_emotion[text]\n",
    "        except KeyError:\n",
    "            # Token missing from dict_emotion: leave the cell at 0. Only the\n",
    "            # lookup miss is ignored; a bare except would also hide real bugs\n",
    "            # and swallow KeyboardInterrupt.\n",
    "            continue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Fixed seed so the 80/20 split, and every metric computed from it, is the\n",
    "# same after a kernel restart.\n",
    "SPLIT_SEED = 42\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(data_X, trainset_data.target,\n",
    "                                                    test_size = 0.2, random_state = SPLIT_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0 , pass acc: 0 , current acc: 0.319702\n",
      "epoch: 1 , training loss:  1.94579729865 , train acc:  0.310604175672\n",
      "epoch: 2 , training loss:  1.83255662903 , train acc:  0.318160571614\n",
      "epoch: 2 , pass acc: 0.319702 , current acc: 0.333161\n",
      "epoch: 3 , training loss:  1.77422007702 , train acc:  0.324265083696\n",
      "epoch: 3 , pass acc: 0.333161 , current acc: 0.333641\n",
      "epoch: 4 , training loss:  1.73213813629 , train acc:  0.327627814925\n",
      "epoch: 4 , pass acc: 0.333641 , current acc: 0.336268\n",
      "epoch: 5 , training loss:  1.7015007427 , train acc:  0.329553660394\n",
      "epoch: 5 , pass acc: 0.336268 , current acc: 0.34169\n",
      "epoch: 6 , training loss:  1.67902636687 , train acc:  0.332163452845\n",
      "epoch: 7 , training loss:  1.66076388262 , train acc:  0.332547422169\n",
      "epoch: 8 , training loss:  1.64492078734 , train acc:  0.335175211558\n",
      "epoch: 9 , training loss:  1.63448477662 , train acc:  0.33533419872\n",
      "epoch: 9 , pass acc: 0.34169 , current acc: 0.343214\n",
      "epoch: 10 , training loss:  1.62506710166 , train acc:  0.336519104364\n",
      "epoch: 11 , training loss:  1.61758198246 , train acc:  0.338414952711\n",
      "epoch: 12 , training loss:  1.61130492262 , train acc:  0.338459948881\n",
      "epoch: 12 , pass acc: 0.343214 , current acc: 0.34385\n",
      "epoch: 13 , training loss:  1.60589425657 , train acc:  0.339941831052\n",
      "epoch: 13 , pass acc: 0.34385 , current acc: 0.344965\n",
      "epoch: 14 , training loss:  1.60230706298 , train acc:  0.340478788254\n",
      "epoch: 14 , pass acc: 0.344965 , current acc: 0.345577\n",
      "epoch: 15 , training loss:  1.59823704796 , train acc:  0.340886755169\n",
      "epoch: 15 , pass acc: 0.345577 , current acc: 0.345721\n",
      "epoch: 16 , training loss:  1.59479282815 , train acc:  0.341813681268\n",
      "epoch: 17 , training loss:  1.59197722076 , train acc:  0.341921672327\n",
      "epoch: 18 , training loss:  1.59029727166 , train acc:  0.342329639815\n",
      "epoch: 19 , training loss:  1.58821955514 , train acc:  0.342629615923\n",
      "epoch: 19 , pass acc: 0.345721 , current acc: 0.346069\n",
      "epoch: 20 , training loss:  1.58590729023 , train acc:  0.342992587703\n",
      "epoch: 21 , training loss:  1.58476289191 , train acc:  0.343181571868\n",
      "epoch: 21 , pass acc: 0.346069 , current acc: 0.346093\n",
      "epoch: 22 , training loss:  1.58317788223 , train acc:  0.343211568999\n",
      "epoch: 22 , pass acc: 0.346093 , current acc: 0.346129\n",
      "epoch: 23 , training loss:  1.58223928652 , train acc:  0.343937511304\n",
      "epoch: 24 , training loss:  1.581324069 , train acc:  0.344210488836\n",
      "epoch: 25 , training loss:  1.5799498498 , train acc:  0.344093499385\n",
      "epoch: 25 , pass acc: 0.346129 , current acc: 0.346681\n",
      "epoch: 26 , training loss:  1.57913519892 , train acc:  0.343271565044\n",
      "epoch: 27 , training loss:  1.57894150871 , train acc:  0.344144495203\n",
      "epoch: 28 , training loss:  1.57819363323 , train acc:  0.343934511987\n",
      "epoch: 28 , pass acc: 0.346681 , current acc: 0.346693\n",
      "epoch: 29 , training loss:  1.57763400238 , train acc:  0.343808521838\n",
      "epoch: 30 , training loss:  1.57694988997 , train acc:  0.344129496589\n",
      "epoch: 31 , training loss:  1.57666215697 , train acc:  0.344261485212\n",
      "epoch: 32 , training loss:  1.57587428501 , train acc:  0.34387451627\n",
      "epoch: 33 , training loss:  1.57585709163 , train acc:  0.344612457303\n",
      "epoch: 34 , training loss:  1.57532724573 , train acc:  0.344135495411\n",
      "epoch: 34 , pass acc: 0.346693 , current acc: 0.346789\n",
      "epoch: 35 , training loss:  1.57491903502 , train acc:  0.345131415544\n",
      "epoch: 36 , training loss:  1.57476220309 , train acc:  0.344660453096\n",
      "epoch: 37 , training loss:  1.57431731845 , train acc:  0.344441471119\n",
      "epoch: 38 , training loss:  1.57403657778 , train acc:  0.344852438069\n",
      "epoch: 39 , training loss:  1.57376307631 , train acc:  0.344438471603\n",
      "epoch: 40 , training loss:  1.57360743953 , train acc:  0.344786443283\n",
      "epoch: 40 , pass acc: 0.346789 , current acc: 0.347029\n",
      "epoch: 41 , training loss:  1.57314552906 , train acc:  0.344798442907\n",
      "epoch: 42 , training loss:  1.57317841783 , train acc:  0.344189491394\n",
      "epoch: 42 , pass acc: 0.347029 , current acc: 0.347173\n",
      "epoch: 43 , training loss:  1.57290436129 , train acc:  0.344063501432\n",
      "epoch: 44 , training loss:  1.57281243183 , train acc:  0.344810441619\n",
      "epoch: 45 , training loss:  1.57236606471 , train acc:  0.344558461326\n",
      "epoch: 46 , training loss:  1.5721133444 , train acc:  0.344099498304\n",
      "epoch: 47 , training loss:  1.57229086386 , train acc:  0.344600458682\n",
      "epoch: 48 , training loss:  1.57197561197 , train acc:  0.344162493574\n",
      "epoch: 49 , training loss:  1.57183423396 , train acc:  0.344378476132\n",
      "epoch: 50 , training loss:  1.57176156109 , train acc:  0.345146414963\n",
      "epoch: 50 , pass acc: 0.347173 , current acc: 0.347736\n",
      "epoch: 51 , training loss:  1.5715926853 , train acc:  0.344750446068\n",
      "epoch: 52 , training loss:  1.57149430833 , train acc:  0.344351478034\n",
      "epoch: 53 , training loss:  1.57140828737 , train acc:  0.344828440114\n",
      "epoch: 54 , training loss:  1.57116867249 , train acc:  0.34527540481\n",
      "epoch: 55 , training loss:  1.5710449117 , train acc:  0.344480468128\n",
      "epoch: 56 , training loss:  1.5708096832 , train acc:  0.34442647244\n",
      "epoch: 57 , training loss:  1.57101057381 , train acc:  0.344888435504\n",
      "epoch: 57 , pass acc: 0.347736 , current acc: 0.347988\n",
      "epoch: 58 , training loss:  1.57068786366 , train acc:  0.34474744703\n",
      "epoch: 59 , training loss:  1.57081300398 , train acc:  0.344450469998\n",
      "epoch: 60 , training loss:  1.57053630458 , train acc:  0.344732448056\n",
      "epoch: 61 , training loss:  1.57055208844 , train acc:  0.344756446065\n",
      "epoch: 62 , training loss:  1.57033461208 , train acc:  0.345179411793\n",
      "epoch: 63 , training loss:  1.57014844279 , train acc:  0.344972428452\n",
      "epoch: 64 , training loss:  1.56999939826 , train acc:  0.345104418309\n",
      "epoch: 65 , training loss:  1.57010611486 , train acc:  0.344864437055\n",
      "epoch: 66 , training loss:  1.57011491876 , train acc:  0.344999426519\n",
      "epoch: 67 , training loss:  1.56992818333 , train acc:  0.34507442078\n",
      "epoch: 68 , training loss:  1.56980569372 , train acc:  0.344510465463\n",
      "epoch: 69 , training loss:  1.56973410537 , train acc:  0.344651454195\n",
      "epoch: 70 , training loss:  1.56976708304 , train acc:  0.34502942402\n",
      "epoch: 71 , training loss:  1.56957211217 , train acc:  0.344354478858\n",
      "epoch: 72 , training loss:  1.56946978537 , train acc:  0.345239407488\n",
      "epoch: 73 , training loss:  1.56935346251 , train acc:  0.344846438861\n",
      "epoch: 74 , training loss:  1.56926864815 , train acc:  0.344564461275\n",
      "epoch: 75 , training loss:  1.56929268564 , train acc:  0.345266404707\n",
      "epoch: 76 , training loss:  1.56907495159 , train acc:  0.344702450399\n",
      "epoch: 77 , training loss:  1.56922572671 , train acc:  0.344630455659\n",
      "epoch: 78 , training loss:  1.56905308527 , train acc:  0.345173412831\n",
      "epoch: 79 , training loss:  1.56924159172 , train acc:  0.34504442315\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-3a6fc8ecede3>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     18\u001b[0m             \u001b[0mbatch_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_Y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m         _, loss = sess.run([model.optimizer, model.cost], feed_dict = {model.X: batch_x, \n\u001b[0;32m---> 20\u001b[0;31m                                                                         model.Y: batch_y})\n\u001b[0m\u001b[1;32m     21\u001b[0m         \u001b[0mtrain_acc\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccuracy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mY\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_y\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m         \u001b[0mtrain_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    787\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    788\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 789\u001b[0;31m                          run_metadata_ptr)\n\u001b[0m\u001b[1;32m    790\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    791\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    995\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    996\u001b[0m       results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 997\u001b[0;31m                              feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m    998\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    999\u001b[0m       \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1130\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1131\u001b[0m       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1132\u001b[0;31m                            target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m   1133\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1134\u001b[0m       return self._do_call(_prun_fn, self._session, handle, feed_dict,\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m   1137\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1138\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1139\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1140\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1141\u001b[0m       \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m   1119\u001b[0m         return tf_session.TF_Run(session, options,\n\u001b[1;32m   1120\u001b[0m                                  \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1121\u001b[0;31m                                  status, run_metadata)\n\u001b[0m\u001b[1;32m   1122\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1123\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Fresh graph and session: 3 LSTM layers of 128 units, 1 input feature per\n",
    "# time step, one output per class, learning rate 0.0001.\n",
    "tf.reset_default_graph()\n",
    "model = Model(3, 128, 1, len(trainset_data.target_names), 0.0001)\n",
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "saver = tf.train.Saver(tf.global_variables())\n",
    "\n",
    "# EARLY_STOPPING is the patience: training stops after that many consecutive\n",
    "# epochs without a new best test accuracy. CURRENT_CHECKPOINT counts them.\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 100, 0, 0, 0\n",
    "batch_size = 120\n",
    "num_classes = len(trainset_data.target_names)\n",
    "# Only full batches are used; a trailing partial batch is skipped.\n",
    "num_batches = train_X.shape[0] // batch_size\n",
    "checkpoint_path = os.path.join(os.getcwd(), 'model-rnn.ckpt')\n",
    "\n",
    "# The test set never changes, so its input tensor and one-hot labels are\n",
    "# built once instead of once per epoch.\n",
    "test_x = np.expand_dims(test_X, axis=2)\n",
    "test_y = np.zeros((test_X.shape[0], num_classes))\n",
    "test_y[np.arange(test_X.shape[0]), test_Y] = 1.0\n",
    "test_feed = {model.X: test_x, model.Y: test_y}\n",
    "# If the model exposes a dropout keep-probability placeholder, switch\n",
    "# dropout off so the test accuracy is not measured on a randomly thinned net.\n",
    "if hasattr(model, 'keep_prob'):\n",
    "    test_feed[model.keep_prob] = 1.0\n",
    "\n",
    "while True:\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:', EPOCH)\n",
    "        break\n",
    "    train_acc, train_loss = 0, 0\n",
    "    for n in range(0, num_batches * batch_size, batch_size):\n",
    "        batch_x = np.expand_dims(train_X[n: n+batch_size, :], axis=2)\n",
    "        batch_y = np.zeros((batch_size, num_classes))\n",
    "        batch_y[np.arange(batch_size), train_Y[n: n+batch_size]] = 1.0\n",
    "        # Loss and accuracy come from the same run as the update step, which\n",
    "        # avoids a second forward pass over every batch.\n",
    "        _, loss, acc = sess.run([model.optimizer, model.cost, model.accuracy],\n",
    "                                feed_dict = {model.X: batch_x, model.Y: batch_y})\n",
    "        train_acc += acc\n",
    "        train_loss += loss\n",
    "    # max() keeps the averages defined when there are fewer samples than one batch.\n",
    "    train_loss /= max(num_batches, 1)\n",
    "    train_acc /= max(num_batches, 1)\n",
    "    # One pass over the test set yields both metrics.\n",
    "    TEST_COST, TEST_ACC = sess.run([model.cost, model.accuracy], feed_dict = test_feed)\n",
    "    # Increment first so both log lines of an epoch carry the same number.\n",
    "    EPOCH += 1\n",
    "    print('epoch:', EPOCH, ', training loss: ', train_loss, ', train acc: ', train_acc)\n",
    "    if TEST_ACC > CURRENT_ACC:\n",
    "        print('epoch:', EPOCH, ', pass acc:', CURRENT_ACC, ', current acc:', TEST_ACC)\n",
    "        CURRENT_ACC = TEST_ACC\n",
    "        # New best: restart the patience counter and keep this checkpoint.\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "        saver.save(sess, checkpoint_path)\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
